import torch.utils.data as data
import cv2,torch
import pandas as pd
import os
import image_utils
import random
from torchvision import transforms

# Preprocessing pipelines, keyed by phase name ("train" / "val").
# Both phases turn the input array into a PIL image, resize it, convert it to
# a tensor and normalize every channel with mean 0.5 / std 0.5. The training
# pipeline additionally applies a random horizontal flip.
_INPUT_SIZE = (224, 224)  # must be an (h, w) pair, not a bare 224, to get an exact 224x224 output
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

_train_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(_INPUT_SIZE),  # plain resize (no cropping) to 224x224
    transforms.RandomHorizontalFlip(),  # random left-right flip
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

_val_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

data_transform = {"train": _train_pipeline, "val": _val_pipeline}


class RafDataSet(data.Dataset):
    """RAF-DB dataset that serves the aligned face images with their labels.

    The partition file ``EmoLabel/list_patition_label.txt`` under ``raf_path``
    is read as a space-separated, header-less table (column 0: file name,
    column 1: label). Rows whose file name starts with ``train`` are used for
    phase ``'train'``; for any other phase the rows starting with ``test`` are
    used. Each file name is mapped to ``Image/aligned/<stem>_aligned.jpg``.

    Args:
        raf_path: Root directory of the dataset.
        phase: Key into the module-level ``data_transform`` dict; also
            selects the train/test partition as described above.
        basic_aug: When true, training samples are passed through one
            randomly chosen function from ``aug_func`` with probability 0.5
            before the phase transform is applied.

    Raises:
        KeyError: If ``phase`` is not a key of ``data_transform``.
    """

    def __init__(self, raf_path, phase, basic_aug=False):
        self.phase = phase
        self.transform = data_transform[phase]
        self.raf_path = raf_path

        NAME_COLUMN = 0
        LABEL_COLUMN = 1
        df = pd.read_csv(
            os.path.join(self.raf_path, 'EmoLabel/list_patition_label.txt'),
            sep=' ',
            header=None,
        )
        prefix = 'train' if phase == 'train' else 'test'
        dataset = df[df[NAME_COLUMN].str.startswith(prefix)]
        file_names = dataset.iloc[:, NAME_COLUMN].values
        # Shift labels down by one so they start at 0:
        # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
        self.label = dataset.iloc[:, LABEL_COLUMN].values - 1

        # Use the RAF-DB aligned images for both training and testing.
        aligned_dir = os.path.join(self.raf_path, 'Image/aligned')
        self.file_paths = [
            os.path.join(aligned_dir, name.split(".")[0] + "_aligned.jpg")
            for name in file_names
        ]

        self.basic_aug = basic_aug
        self.aug_func = [image_utils.flip_image, image_utils.add_gaussian_noise]

    def __len__(self):
        """Return the number of samples in the selected partition."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, label, idx)`` where ``image`` is the output of
            the phase transform (or the RGB array if no transform is set).

        Raises:
            FileNotFoundError: If the image file does not exist.
            OSError: If the file exists but OpenCV cannot decode it.
        """
        path = self.file_paths[idx]
        # cv2.imread returns None instead of raising on failure, so report
        # the offending path explicitly rather than failing later with an
        # opaque "NoneType is not subscriptable" error.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"RAF-DB image not found: {path}")
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not decode image: {path}")
        # BGR to RGB; cvtColor yields a contiguous array (unlike a ::-1 slice).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.label[idx]

        if self.phase == 'train' and self.basic_aug and random.uniform(0, 1) > 0.5:
            image = random.choice(self.aug_func)(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label, idx

# train_dataset = RafDataSet('../', phase='train', basic_aug=True)
# train_loader = torch.utils.data.DataLoader(train_dataset,
#                                                batch_size=64,
#                                                num_workers=0,
#                                                shuffle=True,
#                                                pin_memory=True)